{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02111.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02111.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61f2MtXYLLps",
        "colab_type": "text"
      },
      "source": [
        "## 10.12 机器翻译"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cukcwvQaQNWJ",
        "colab_type": "text"
      },
      "source": [
        "### 10.12.1 读取和预处理数据"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eUhGwLNlLkbE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import collections \n",
        "import os \n",
        "import io \n",
        "import math \n",
        "import torch \n",
        "from torch import nn \n",
        "import torch.nn.functional as F \n",
        "import torchtext.vocab as Vocab \n",
        "import torch.utils.data as Data \n",
        "import d2l \n",
        "\n",
        "# Special tokens: padding, beginning-of-sequence and end-of-sequence markers.\n",
        "PAD, BOS, EOS = '<pad>', '<bos>', '<eos>'\n",
        "# Select the GPU when CUDA is available, otherwise the CPU.\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yLx1Y5-RPGRp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def process_one_seq(seq_tokens, all_tokens, all_seqs, max_seq_len):\n",
        "    \"\"\"Record the tokens of one sequence and store its padded form.\n",
        "\n",
        "    The raw tokens are added to ``all_tokens``. ``seq_tokens`` is then extended\n",
        "    in place with one EOS token followed by PAD tokens up to ``max_seq_len``\n",
        "    and the same list object is appended to ``all_seqs``.\n",
        "    \"\"\"\n",
        "    all_tokens.extend(seq_tokens)\n",
        "    # Number of PAD tokens still needed once the EOS token has been added.\n",
        "    num_pad = max_seq_len - len(seq_tokens) - 1\n",
        "    seq_tokens.append(EOS)\n",
        "    seq_tokens.extend([PAD] * num_pad)\n",
        "    all_seqs.append(seq_tokens)\n",
        "\n",
        "def build_data(all_tokens, all_seqs):\n",
        "    \"\"\"Build a vocabulary from the tokens and index-encode the sequences.\n",
        "\n",
        "    Returns the vocabulary together with a tensor that holds, for every\n",
        "    sequence in ``all_seqs``, the vocabulary index of each of its tokens.\n",
        "    \"\"\"\n",
        "    token_counts = collections.Counter(all_tokens)\n",
        "    vocab = Vocab.Vocab(token_counts, specials=[PAD, BOS, EOS])\n",
        "    indices = []\n",
        "    for seq in all_seqs:\n",
        "        indices.append([vocab.stoi[token] for token in seq])\n",
        "    return vocab, torch.tensor(indices)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HKjtMwdWQUW5",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 134
        },
        "outputId": "60d0d87f-dcc3-4d68-8ba2-6726d5c02015"
      },
      "source": [
        "# Fetch the d2l-zh repository, which provides data/fr-en-small.txt.\n",
        "# The directory check keeps the cell re-runnable (git clone fails when the\n",
        "# target directory already exists) and --depth 1 avoids downloading the history.\n",
        "![ -d d2l-zh ] || git clone --depth 1 https://github.com/d2l-ai/d2l-zh.git"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'd2l-zh'...\n",
            "remote: Enumerating objects: 5, done.\u001b[K\n",
            "remote: Counting objects: 100% (5/5), done.\u001b[K\n",
            "remote: Compressing objects: 100% (5/5), done.\u001b[K\n",
            "remote: Total 15713 (delta 0), reused 1 (delta 0), pack-reused 15708\u001b[K\n",
            "Receiving objects: 100% (15713/15713), 159.54 MiB | 54.46 MiB/s, done.\n",
            "Resolving deltas: 100% (11138/11138), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "atmdEKFqP9Vi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def read_data(max_seq_len, data_path='d2l-zh/data/fr-en-small.txt'):\n",
        "    \"\"\"Read tab-separated sentence pairs and turn them into an indexed dataset.\n",
        "\n",
        "    Args:\n",
        "        max_seq_len: Length every stored sequence is padded to. Pairs in which\n",
        "            either sentence has more than ``max_seq_len - 1`` tokens are\n",
        "            dropped, which leaves room for the EOS token.\n",
        "        data_path: Path of the text file. Each non-blank line holds an input\n",
        "            sentence and an output sentence separated by a tab; tokens are\n",
        "            separated by single spaces.\n",
        "\n",
        "    Returns:\n",
        "        A tuple ``(in_vocab, out_vocab, dataset)`` where ``dataset`` is a\n",
        "        ``TensorDataset`` pairing the indexed input and output sequences.\n",
        "    \"\"\"\n",
        "    in_tokens, out_tokens, in_seqs, out_seqs = [], [], [], []\n",
        "    # Decode explicitly as UTF-8 so that non-ASCII characters do not depend on\n",
        "    # the default encoding of the platform.\n",
        "    with io.open(data_path, encoding='utf-8') as f:\n",
        "        for line in f:\n",
        "            line = line.rstrip()\n",
        "            if not line:\n",
        "                # Blank lines (e.g. at the end of the file) carry no sentence pair.\n",
        "                continue\n",
        "            in_seq, out_seq = line.split('\\t')\n",
        "            in_seq_tokens, out_seq_tokens = in_seq.split(' '), out_seq.split(' ')\n",
        "            if max(len(in_seq_tokens), len(out_seq_tokens)) > max_seq_len - 1:\n",
        "                continue\n",
        "            process_one_seq(in_seq_tokens, in_tokens, in_seqs, max_seq_len)\n",
        "            process_one_seq(out_seq_tokens, out_tokens, out_seqs, max_seq_len)\n",
        "    in_vocab, in_data = build_data(in_tokens, in_seqs)\n",
        "    out_vocab, out_data = build_data(out_tokens, out_seqs)\n",
        "    return in_vocab, out_vocab, Data.TensorDataset(in_data, out_data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yONg6tapUccc",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "2220df70-7aae-49bc-9400-bff106594b31"
      },
      "source": [
        "# Every stored sequence is padded to this many tokens (EOS included).\n",
        "max_seq_len = 7 \n",
        "in_vocab, out_vocab, dataset = read_data(max_seq_len)\n",
        "# Show the first (input, output) pair of index tensors.\n",
        "dataset[0]"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(tensor([ 5,  4, 45,  3,  2,  0,  0]), tensor([ 8,  4, 27,  3,  2,  0,  0]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zf-4GTXGUoNJ",
        "colab_type": "text"
      },
      "source": [
        "### 10.12.2 含注意力机制的编码器-解码器"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jU-TxL1iUunW",
        "colab_type": "text"
      },
      "source": [
        "#### 1 编码器"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FZZr8UtKU1d_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Encoder(nn.Module):\n",
        "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, drop_prob=0, **kwargs):\n",
        "        super(Encoder, self).__init__(**kwargs)\n",
        "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
        "        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=drop_prob)\n",
        "\n",
        "    def forward(self, inputs, state):\n",
        "        embedding = self.embedding(inputs.long()).permute(1, 0, 2)\n",
        "        return self.rnn(embedding, state)\n",
        "\n",
        "    def begin_state(self):\n",
        "        return None "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uHXXJn9w0qZ0",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "49e7788b-6d3c-4128-cec4-926cb99b4b67"
      },
      "source": [
        "# Shape check: a batch of 4 sequences of length 7 through a 2-layer encoder with 16 hidden units.\n",
        "encoder = Encoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)\n",
        "output, state = encoder(torch.zeros((4, 7)), encoder.begin_state())\n",
        "# Display the shapes of the GRU outputs and of the final state.\n",
        "output.shape, state.shape"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OC0x51V728dI",
        "colab_type": "text"
      },
      "source": [
        "#### 2 注意力机制"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IzHrrMPa2_Zk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def attention_model(input_size, attention_size):\n",
        "    model = nn.Sequential(nn.Linear(input_size, attention_size, bias=False), \n",
        "                nn.Tanh(), \n",
        "                nn.Linear(attention_size, 1, bias=False))\n",
        "    return model"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Vm2w2BSp1UuR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def attention_forward(model, enc_states, dec_state):\n",
        "    dec_states = dec_state.unsqueeze(dim=0).expand_as(enc_states)\n",
        "    enc_and_dec_states = torch.cat((enc_states, dec_states), dim=2)\n",
        "    e = model(enc_and_dec_states)\n",
        "    alpha = F.softmax(e, dim=0)\n",
        "    return (alpha * enc_states).sum(dim=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TStoBBqC15Hv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "6ca42838-75db-4f57-b5ce-9c5e9f22c8e5"
      },
      "source": [
        "# Shape check with 10 encoder time steps, a batch of 4 and 8 hidden units.\n",
        "seq_len, batch_size, num_hiddens = 10, 4, 8 \n",
        "# The attention network scores encoder and decoder states concatenated, hence 2 * num_hiddens inputs.\n",
        "model = attention_model(2 * num_hiddens, 10)\n",
        "enc_states = torch.zeros((seq_len, batch_size, num_hiddens))\n",
        "dec_state = torch.zeros((batch_size, num_hiddens))\n",
        "# Display the shape of the resulting context vector.\n",
        "attention_forward(model, enc_states, dec_state).shape"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([4, 8])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D0tUSYjS2t8G",
        "colab_type": "text"
      },
      "source": [
        "#### 3 含注意力机制的解码器"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FYk42cZ-3f7j",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Decoder(nn.Module):\n",
        "    \"\"\"GRU decoder that attends over the encoder states at every step.\"\"\"\n",
        "\n",
        "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, attention_size, drop_prob=0):\n",
        "        super().__init__()\n",
        "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
        "        # The attention network scores an encoder state concatenated with a\n",
        "        # decoder state, hence 2 * num_hiddens inputs.\n",
        "        self.attention = attention_model(2 * num_hiddens, attention_size)\n",
        "        # The GRU consumes the embedded token concatenated with the context vector.\n",
        "        self.rnn = nn.GRU(num_hiddens + embed_size, num_hiddens, num_layers, dropout=drop_prob)\n",
        "        self.out = nn.Linear(num_hiddens, vocab_size)\n",
        "\n",
        "    def forward(self, cur_input, state, enc_states):\n",
        "        \"\"\"Run a single decoding step.\n",
        "\n",
        "        The last entry of ``state`` is used as the query for attention over\n",
        "        ``enc_states``. The resulting context vector is concatenated with the\n",
        "        embedding of ``cur_input`` and fed to the GRU as one time step.\n",
        "        Returns the output-layer result for this step and the new GRU state.\n",
        "        \"\"\"\n",
        "        query = state[-1]\n",
        "        context = attention_forward(self.attention, enc_states, query)\n",
        "        embedded = self.embedding(cur_input)\n",
        "        step_input = torch.cat((embedded, context), dim=1).unsqueeze(0)\n",
        "        rnn_output, state = self.rnn(step_input, state)\n",
        "        logits = self.out(rnn_output).squeeze(dim=0)\n",
        "        return logits, state\n",
        "\n",
        "    def begin_state(self, enc_state):\n",
        "        \"\"\"Use the final encoder state as the initial decoder state.\"\"\"\n",
        "        return enc_state"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y1QsJ_MI42an",
        "colab_type": "text"
      },
      "source": [
        "### 10.12.3 训练模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vFxNBCev48Qm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def batch_loss(encoder, decoder, X, Y, loss):\n",
        "    \"\"\"Average per-token loss of one mini-batch under teacher forcing.\n",
        "\n",
        "    Args:\n",
        "        encoder: Module called as ``encoder(X, state)``.\n",
        "        decoder: Module called as ``decoder(dec_input, dec_state, enc_outputs)``.\n",
        "        X: Input index tensor with the batch on axis 0.\n",
        "        Y: Target index tensor with the batch on axis 0; it is iterated over\n",
        "            one time step at a time.\n",
        "        loss: Loss returning one value per example, e.g.\n",
        "            ``nn.CrossEntropyLoss(reduction='none')``.\n",
        "\n",
        "    Returns:\n",
        "        A one-element tensor: the summed loss over all target positions up to\n",
        "        and including EOS, divided by the number of such positions.\n",
        "    \"\"\"\n",
        "    batch_size = X.shape[0]\n",
        "    # Create the helper tensors on the device of the batch so that the function\n",
        "    # also works when the data and the models live on a GPU.\n",
        "    dev = X.device\n",
        "    enc_state = encoder.begin_state()\n",
        "    enc_outputs, enc_state = encoder(X, enc_state)\n",
        "    dec_state = decoder.begin_state(enc_state)\n",
        "    # The first decoder input is BOS for every example.\n",
        "    dec_input = torch.tensor([out_vocab.stoi[BOS]] * batch_size, device=dev)\n",
        "    mask, num_not_pad_tokens = torch.ones(batch_size, device=dev), 0\n",
        "    l = torch.tensor([0.0], device=dev)\n",
        "    for y in Y.permute(1, 0):\n",
        "        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)\n",
        "        l = l + (mask * loss(dec_output, y)).sum()\n",
        "        dec_input = y  # teacher forcing: feed the ground-truth token\n",
        "        num_not_pad_tokens += mask.sum().item()\n",
        "        # Once EOS has been seen, every later position of that example is padding.\n",
        "        mask = mask * (y != out_vocab.stoi[EOS]).float()\n",
        "    return l / num_not_pad_tokens"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Kmw_esvp6Mmb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train(encoder, decoder, dataset, lr, batch_size, num_epochs):\n",
        "    \"\"\"Train the encoder and the decoder jointly with Adam.\n",
        "\n",
        "    Args:\n",
        "        encoder, decoder: Modules to optimise; each gets its own Adam optimiser.\n",
        "        dataset: Dataset yielding ``(X, Y)`` pairs; it is wrapped in a shuffling\n",
        "            ``DataLoader``.\n",
        "        lr: Learning rate used by both optimisers.\n",
        "        batch_size: Mini-batch size.\n",
        "        num_epochs: Number of passes over ``dataset``.\n",
        "\n",
        "    The mean mini-batch loss of the epoch is printed every 10 epochs.\n",
        "    \"\"\"\n",
        "    # Make sure dropout is active even if the modules were left in eval mode\n",
        "    # by code that ran earlier.\n",
        "    encoder.train()\n",
        "    decoder.train()\n",
        "    enc_optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)\n",
        "    dec_optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)\n",
        "    # Per-example losses are needed because batch_loss masks padded positions.\n",
        "    loss = nn.CrossEntropyLoss(reduction='none')\n",
        "    data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)\n",
        "    for epoch in range(num_epochs):\n",
        "        l_sum = 0.0\n",
        "        for X, Y in data_iter:\n",
        "            enc_optimizer.zero_grad()\n",
        "            dec_optimizer.zero_grad()\n",
        "            l = batch_loss(encoder, decoder, X, Y, loss)\n",
        "            l.backward()\n",
        "            enc_optimizer.step()\n",
        "            dec_optimizer.step()\n",
        "            l_sum += l.item()\n",
        "        if (epoch + 1) % 10 == 0:\n",
        "            print(\"epoch %d, loss %.3f\" % (epoch + 1, l_sum / len(data_iter)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F_0dxuTx7hLj",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "58ce4c2b-c066-41f4-9dbd-7f77618e295d"
      },
      "source": [
        "# Model sizes. The encoder and the decoder are built with the same num_hiddens and num_layers.\n",
        "embed_size, num_hiddens, num_layers = 64, 64, 2 \n",
        "# Attention width, dropout probability and optimisation settings.\n",
        "attention_size, drop_prob, lr, batch_size, num_epochs = 10, 0.5, 0.01, 2, 50\n",
        "encoder = Encoder(len(in_vocab), embed_size, num_hiddens, num_layers, drop_prob)\n",
        "decoder = Decoder(len(out_vocab), embed_size, num_hiddens, num_layers, attention_size, drop_prob)\n",
        "train(encoder, decoder, dataset, lr, batch_size, num_epochs)"
      ],
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 10, loss 0.586\n",
            "epoch 20, loss 0.215\n",
            "epoch 30, loss 0.098\n",
            "epoch 40, loss 0.026\n",
            "epoch 50, loss 0.010\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z9neUt3L8QVQ",
        "colab_type": "text"
      },
      "source": [
        "### 10.12.4 预测不定长的序列"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y_xJOIvW9rm9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def translate(encoder, decoder, input_seq, max_seq_len):\n",
        "    \"\"\"Greedily translate a space-separated input sentence.\n",
        "\n",
        "    Args:\n",
        "        encoder, decoder: Trained modules.\n",
        "        input_seq: Source sentence whose tokens are separated by single spaces.\n",
        "        max_seq_len: Length the input is padded to and maximum number of\n",
        "            decoding steps.\n",
        "\n",
        "    Returns:\n",
        "        The list of predicted output tokens, without the EOS token.\n",
        "\n",
        "    Decoding runs in eval mode (dropout disabled) and without gradient\n",
        "    tracking; the previous training/eval mode of both modules is restored\n",
        "    before returning.\n",
        "    \"\"\"\n",
        "    in_tokens = input_seq.split(' ')\n",
        "    in_tokens += [EOS] + [PAD] * (max_seq_len - len(in_tokens) - 1)\n",
        "    enc_input = torch.tensor([[in_vocab.stoi[tk] for tk in in_tokens]])\n",
        "    # Remember the current modes so that they can be restored afterwards.\n",
        "    enc_was_training, dec_was_training = encoder.training, decoder.training\n",
        "    encoder.eval()\n",
        "    decoder.eval()\n",
        "    output_tokens = []\n",
        "    try:\n",
        "        with torch.no_grad():\n",
        "            enc_state = encoder.begin_state()\n",
        "            enc_output, enc_state = encoder(enc_input, enc_state)\n",
        "            dec_input = torch.tensor([out_vocab.stoi[BOS]])\n",
        "            dec_state = decoder.begin_state(enc_state)\n",
        "            for _ in range(max_seq_len):\n",
        "                dec_output, dec_state = decoder(dec_input, dec_state, enc_output)\n",
        "                pred = dec_output.argmax(dim=1)\n",
        "                pred_token = out_vocab.itos[int(pred.item())]\n",
        "                if pred_token == EOS:\n",
        "                    break\n",
        "                output_tokens.append(pred_token)\n",
        "                # Feed the prediction back in as the next decoder input.\n",
        "                dec_input = pred\n",
        "    finally:\n",
        "        encoder.train(enc_was_training)\n",
        "        decoder.train(dec_was_training)\n",
        "    return output_tokens"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3Oc-MvFD_MtO",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "bdca0189-b095-4bc2-dabb-c6dc76380506"
      },
      "source": [
        "# Translate a sample sentence with the trained model.\n",
        "input_seq = 'ils regardent'\n",
        "translate(encoder, decoder, input_seq, max_seq_len)"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['they', 'are', 'watching', '.']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 33
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1dm5W_DT_WWo",
        "colab_type": "text"
      },
      "source": [
        "### 10.12.5 评价翻译结果"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kUevLAcaAEtP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def bleu(pred_tokens, label_tokens, k):\n",
        "    len_pred, len_label = len(pred_tokens), len(label_tokens)\n",
        "    score = math.exp(min(0, 1 - len_label / len_pred))\n",
        "    for n in range(1, k + 1):\n",
        "        num_matches, label_subs = 0, collections.defaultdict(int)\n",
        "        for i in range(len_label - n + 1):\n",
        "            label_subs[''.join(label_tokens[i: i + n])] += 1 \n",
        "        for i in range(len_pred - n + 1):\n",
        "            if label_subs[''.join(pred_tokens[i: i + n])] > 0:\n",
        "                num_matches += 1 \n",
        "                label_subs[''.join(pred_tokens[i: i + n])] -= 1 \n",
        "        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))\n",
        "    return score "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-sNrOWqZBaI6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def score(input_seq, label_seq, k):\n",
        "    \"\"\"Translate ``input_seq`` and print its BLEU score against ``label_seq``.\n",
        "\n",
        "    Uses the module-level ``encoder``, ``decoder`` and ``max_seq_len``;\n",
        "    ``label_seq`` is split on single spaces to obtain the reference tokens.\n",
        "    \"\"\"\n",
        "    predicted = translate(encoder, decoder, input_seq, max_seq_len)\n",
        "    reference = label_seq.split(' ')\n",
        "    bleu_value = bleu(predicted, reference, k)\n",
        "    print('bleu %.3f, predict: %s' % (bleu_value, ' '.join(predicted)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F4EUW8ehB6TP",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "e27bd8a6-f733-4d62-e2be-dfef8dcff2cd"
      },
      "source": [
        "# Translate the sentence and report BLEU against the reference, using n-grams of length 1 and 2.\n",
        "score('ils regardent .', 'they are watching .', k=2)"
      ],
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "bleu 1.000, predict: they are watching .\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hjx5_fBDCGB7",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "bb79445d-0e89-4fb2-b57a-e2ccbabeb6f0"
      },
      "source": [
        "# Second example, again scored with n-grams of length 1 and 2.\n",
        "score('ils sont canadienne .', 'they are canadian .', k=2)"
      ],
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "bleu 0.658, predict: they are actors .\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}